// The weave command is a simple preprocessor for markdown files.
// It builds a table of contents and processes %include directives.
//
// Example usage:
//
//	$ go run internal/cmd/weave go-types.md > README.md
//
// The weave command copies lines of the input file to standard output, with two
// exceptions:
//
// If a line begins with "%toc", it is replaced with a table of contents
// consisting of links to the top two levels of headers below the %toc symbol.
//
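// If a line begins with "%include FILENAME TAG", it is replaced with the lines
// of the file between lines ending in "!+TAG" and "!-TAG", set in a fenced Go
// code block with their common leading tabs removed. TAG can be omitted, in
// which case the delimiters are simply "!+" and "!-". A file may contain
// several delimited sections; all of them are included, in order.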
//
// Before the included lines, a line of the form
//
//	// go get PACKAGE
//
// is output, where PACKAGE is constructed from the module path (currently
// hard-coded as golang.org/x/example), the base name of the current directory,
// and the directory of FILENAME.
// This caption can be suppressed by putting "-" as the final word of the %include line.
package main

import (
	"bufio"
	"bytes"
	"flag"
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"
	"regexp"
	"strings"
)

// output holds the value of the -o flag: the file to write the result to.
// An empty value means standard output.
var output = flag.String("o", "", "output file (empty means stdout)")

// main preprocesses the markdown file named by its single argument. It makes
// two passes over the input: the first collects headers for the table of
// contents, the second copies lines to the output while expanding %toc and
// %include directives. Any error is fatal.
func main() {
	flag.Usage = func() {
		fmt.Fprintf(flag.CommandLine.Output(), "usage: weave [flags] <input.md>\n\nflags:\n")
		flag.PrintDefaults()
	}
	flag.Parse()
	if flag.NArg() != 1 {
		flag.Usage()
		os.Exit(2)
	}

	log.SetFlags(0)
	log.SetPrefix("weave: ")

	// The base name of the working directory becomes part of the package
	// path shown in "go get" captions.
	wd, err := os.Getwd()
	if err != nil {
		log.Fatal(err)
	}
	curDir := filepath.Base(wd)

	in, err := os.Open(flag.Arg(0))
	if err != nil {
		log.Fatal(err)
	}
	defer in.Close()

	out := os.Stdout
	if *output != "" {
		out, err = os.Create(*output)
		if err != nil {
			log.Fatal(err)
		}
		// Check the error from Close rather than discarding it.
		defer func() {
			if err := out.Close(); err != nil {
				log.Fatal(err)
			}
		}()
	}

	// printf writes to out and treats any write error as fatal.
	// NOTE: log.Fatal exits via os.Exit, so deferred calls (including the
	// Close above) do not run on that path.
	printf := func(format string, args ...any) {
		if _, err := fmt.Fprintf(out, format, args...); err != nil {
			log.Fatalf("writing failed: %v", err)
		}
	}

	printf("<!-- Autogenerated by weave; DO NOT EDIT -->\n")

	// Pass 1: collect the headers that follow the last %toc directive.
	type tocEntry struct {
		depth  int    // header level: the number of leading '#' characters
		text   string // header text, words joined by single spaces
		anchor string // link fragment derived from the header text
	}
	var (
		toc []tocEntry
		// We indent toc items according to their header depth, so that nested
		// headers result in nested lists. However, we want the lowest header depth
		// to correspond to the root of the list. Otherwise, the entire list is
		// indented, which turns it into a code block rather than an outline.
		minTocDepth int
	)
	scanner := bufio.NewScanner(in)
	for scanner.Scan() {
		line := scanner.Text()
		// Only headers ('#') and directives ('%') matter in this pass.
		if line == "" || (line[0] != '#' && line[0] != '%') {
			continue
		}
		line = strings.TrimSpace(line)
		switch {
		case strings.HasPrefix(line, "%toc"):
			// Use the same test as pass 2, so that the line which emits the
			// table of contents is also the one that discards the headers
			// above it.
			toc = nil
			minTocDepth = 0
		case strings.HasPrefix(line, "# "),
			strings.HasPrefix(line, "## "),
			strings.HasPrefix(line, "### "),
			strings.HasPrefix(line, "#### "):
			words := strings.Fields(line)
			depth := len(words[0])
			if minTocDepth == 0 || depth < minTocDepth {
				minTocDepth = depth
			}
			words = words[1:]
			text := strings.Join(words, " ")
			// The anchor is the lower-cased, hyphen-joined words with the
			// characters "**", "`" and "_" removed.
			anchor := strings.Join(words, "-")
			anchor = strings.ToLower(anchor)
			anchor = strings.ReplaceAll(anchor, "**", "")
			anchor = strings.ReplaceAll(anchor, "`", "")
			anchor = strings.ReplaceAll(anchor, "_", "")
			toc = append(toc, tocEntry{depth: depth, text: text, anchor: anchor})
		}
	}
	if err := scanner.Err(); err != nil {
		log.Fatal(err)
	}

	// Pass 2: copy the input to the output, expanding directives.
	if _, err := in.Seek(0, io.SeekStart); err != nil {
		log.Fatalf("can't rewind input: %v", err)
	}
	scanner = bufio.NewScanner(in)
	for scanner.Scan() {
		line := scanner.Text()
		switch {
		case strings.HasPrefix(line, "%toc"): // ToC
			for _, h := range toc {
				// Only print two levels of headings, nested by their depth
				// relative to the shallowest header collected.
				if rel := h.depth - minTocDepth; rel <= 1 {
					printf("%s1. [%s](#%s)\n", strings.Repeat("\t", rel), h.text, h.anchor)
				}
			}
		case strings.HasPrefix(line, "%include"):
			words := strings.Fields(line)
			var section string
			caption := true
			switch len(words) {
			case 2: // %include filename
			// Nothing to do.
			case 3: // %include filename section OR %include filename -
				if words[2] == "-" {
					caption = false
				} else {
					section = words[2]
				}
			case 4: // %include filename section -
				section = words[2]
				if words[3] != "-" {
					log.Fatalf("last word is not '-': %s", line)
				}
				caption = false
			default:
				log.Fatalf("wrong # words (want 2-4): %s", line)
			}
			filename := words[1]

			if caption {
				// Package paths always use forward slashes, whereas
				// filepath.Dir uses the OS separator (backslash on Windows).
				printf("\t// go get golang.org/x/example/%s/%s\n\n",
					curDir, filepath.ToSlash(filepath.Dir(filename)))
			}

			s, err := include(filename, section)
			if err != nil {
				log.Fatal(err)
			}
			printf("```go\n")
			printf("%s\n", cleanListing(s)) // TODO(adonovan): escape /^```/ in s
			printf("```\n")
		default:
			printf("%s\n", line)
		}
	}
	if err := scanner.Err(); err != nil {
		log.Fatal(err)
	}
}

// include returns the lines of file that lie between marker lines, each
// returned line prefixed with a tab and terminated by a newline.
//
// A line ending in "!+tag" turns copying on and a line ending in "!-tag"
// turns it off; the marker lines themselves are never copied. This holds
// even when tag is empty, in which case the markers are "!+" and "!-".
// If the file contains several such sections, all of them are concatenated.
//
// tag is matched literally, so it may safely contain characters that are
// special in regular expressions. An error is returned if the file cannot
// be read or if no lines were selected.
func include(file, tag string) (string, error) {
	f, err := os.Open(file)
	if err != nil {
		return "", err
	}
	defer f.Close()

	// QuoteMeta makes the tag literal: without it, a tag such as "a.b" would
	// also match the markers of section "axb". The trailing "$" requires the
	// marker to end the line, so "!+a" does not match "!+ab".
	quoted := regexp.QuoteMeta(tag)
	startre, err := regexp.Compile(`!\+` + quoted + `$`)
	if err != nil {
		return "", err
	}
	endre, err := regexp.Compile(`!-` + quoted + `$`)
	if err != nil {
		return "", err
	}

	var text bytes.Buffer
	in := bufio.NewScanner(f)
	var on bool
	for in.Scan() {
		line := in.Text()
		switch {
		case startre.MatchString(line):
			on = true
		case endre.MatchString(line):
			on = false
		case on:
			text.WriteByte('\t')
			text.WriteString(line)
			text.WriteByte('\n')
		}
	}
	if err := in.Err(); err != nil {
		return "", fmt.Errorf("reading %s: %w", file, err)
	}
	if text.Len() == 0 {
		return "", fmt.Errorf("no lines of %s matched tag %q", file, tag)
	}
	return text.String(), nil
}

// cleanListing tidies an included code listing. Whitespace-only lines are
// emptied, and the indentation shared by all non-blank lines is stripped.
// Only leading tabs count as indentation; spaces do not. Blank lines at the
// start and end are dropped. The remaining lines are rejoined with "\n",
// with no trailing newline.
func cleanListing(text string) string {
	lines := strings.Split(text, "\n")

	// Empty out whitespace-only lines and find the smallest tab indentation
	// among the others; -1 means no non-blank line has been seen yet.
	indent := -1
	for i := range lines {
		if strings.TrimSpace(lines[i]) == "" {
			lines[i] = ""
			continue
		}
		if n := leadingTabs(lines[i]); indent < 0 || n < indent {
			indent = n
		}
	}

	// Dedent every non-blank line by the common amount. Each such line has
	// at least indent leading tabs, so the slice only ever drops tabs.
	if indent > 0 {
		for i := range lines {
			if lines[i] != "" {
				lines[i] = lines[i][indent:]
			}
		}
	}

	// Drop blank lines from both ends.
	start, end := 0, len(lines)
	for start < end && lines[start] == "" {
		start++
	}
	for end > start && lines[end-1] == "" {
		end--
	}
	return strings.Join(lines[start:end], "\n")
}

// leadingTabs reports how many consecutive tab characters s begins with.
func leadingTabs(s string) int {
	return len(s) - len(strings.TrimLeft(s, "\t"))
}
